Modeling of Duration Estimation under Memory Load

1 Modeling of duration estimation under memory load

1.1 Model structure

1.1.1 Duration Encoding

Given that the scalar property is a key feature of duration estimation, the sensory measure (S) is assumed to be the logarithm of the physical duration (D):

\[S = \log(D) + \epsilon\]

where \(\epsilon\) is the noise term.
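A quick simulation illustrates why log-scale encoding captures the scalar property: constant noise in log space yields a coefficient of variation (sd/mean) that is roughly independent of the physical duration. The noise level `sigma` below is an arbitrary illustrative value, not a fitted estimate.

```python
import math
import random

random.seed(1)

def perceive(D, sigma=0.1):
    """One noisy log-scale measurement, mapped back to linear time."""
    return math.exp(math.log(D) + random.gauss(0, sigma))

def cv(D, n=20000, sigma=0.1):
    """Empirical coefficient of variation of the percept at duration D."""
    xs = [perceive(D, sigma) for _ in range(n)]
    m = sum(xs) / n
    sd = math.sqrt(sum((x - m) ** 2 for x in xs) / n)
    return sd / m

# Constant noise in log space -> near-constant CV across durations
# (the scalar property, i.e., Weber's law for time).
cv_short = cv(0.5)
cv_long = cv(1.7)
```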

The encoding is influenced by the memory load (M), and the encoding function is assumed to be Gaussian, \(S_{wm} \sim N(\mu_{wm}, \sigma_{wm}^2)\), where \(\mu_{wm}\) and \(\sigma_{wm}^2\) are the mean and variance of the encoding function:

\[\mu_{wm} = \log{D} - k_s\log(M)\]

\[\sigma_{wm}^2 = \sigma_s^2 (1+ l_s\cdot \log(M))\]

where \(M\) represents the set size of the working memory task, \(k_s\) and \(l_s\) are the scaling factors, and \(\sigma_s^2\) is the variance of the sensory measure.

Note on Gap Effect: When the retention interval (gap) between encoding and reproduction is manipulated, the variance term becomes:

\[\sigma_{wm}^2 = \sigma_s^2 (1+ l_s\cdot \log(M + T_{gap} - 1))\]

where \(T_{gap}\) is the retention interval duration.
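As a minimal numeric sketch of the encoding stage (with illustrative, non-fitted values for \(k_s\), \(l_s\), and \(\sigma_s^2\); setting \(T_{gap}=1\) recovers the no-gap variance term):

```python
import math

def encode(D, M, T_gap=1.0, k_s=0.1, l_s=0.2, sigma_s2=0.05):
    """Mean and variance of the log-space encoding S_wm.

    k_s, l_s, and sigma_s2 are illustrative values. With T_gap = 1 the
    variance term reduces to sigma_s2 * (1 + l_s * log(M)).
    """
    mu_wm = math.log(D) - k_s * math.log(M)
    var_wm = sigma_s2 * (1 + l_s * math.log(M + T_gap - 1))
    return mu_wm, var_wm

# A higher load shortens the encoded log-duration and inflates its variance;
# a longer retention gap inflates the variance further.
mu1, v1 = encode(D=1.0, M=1)
mu5, v5 = encode(D=1.0, M=5)
_, v5g = encode(D=1.0, M=5, T_gap=2.0)
```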

1.1.2 Bayesian integration

Given that all trials were randomly intermixed, the duration estimate follows a posterior distribution \(N(\mu_{post}, \sigma_{post}^2)\), where \(\mu_{post}\) and \(\sigma_{post}^2\) are its mean and variance. According to Bayesian integration theory, the posterior mean is a precision-weighted average of the sensory measure and the prior mean:

\[\mu'_{post} = (1-w_p)\mu_{wm} + w_p\mu_{prior}\]

where \(w_p = \frac{1/\sigma_{prior}^2}{1/\sigma_{wm}^2 + 1/\sigma_{prior}^2}\), and the posterior variance is \(\sigma_{post}^2 = \frac{\sigma_{wm}^2\,\sigma_{prior}^2}{\sigma_{wm}^2 + \sigma_{prior}^2}\).
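The integration can be sketched as a precision-weighted fusion; note that \(w_p\) simplifies to \(\sigma_{wm}^2/(\sigma_{wm}^2+\sigma_{prior}^2)\), so the noisier the sensory measure, the stronger the pull toward the prior (the central-tendency effect). All numeric values below are illustrative.

```python
import math

def integrate(mu_wm, var_wm, mu_prior, var_prior):
    """Fuse the noisy sensory measure with the prior (log space)."""
    # w_p = (1/var_prior) / (1/var_wm + 1/var_prior)
    #     = var_wm / (var_wm + var_prior)
    w_p = var_wm / (var_wm + var_prior)
    mu_post = (1 - w_p) * mu_wm + w_p * mu_prior
    var_post = var_wm * var_prior / (var_wm + var_prior)
    return w_p, mu_post, var_post

# Same sensory mean, low vs. high sensory noise, prior mean at log(1) = 0:
w_lo, mu_lo, _ = integrate(math.log(0.5), 0.02, 0.0, 0.1)
w_hi, mu_hi, _ = integrate(math.log(0.5), 0.20, 0.0, 0.1)
```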

1.1.3 Duration Reproduction

Maintaining a number (‘load’) of items in working memory during the duration-reproduction phase would also influence the monitoring of the elapsed time of the reproduction. If we conceive of monitoring the sensory ‘elapsed time’ (i.e., the time from the starting key press onwards), \(\mu_{elapsed}\), as the counting of ‘clock ticks’ by an accumulator, then lapses and diversion of attention to other, non-temporal processes would result in some ticks being lost, or missed, in the count. Here we assume that the loss scales with the (log) memory load, that is, the monitored elapsed time is \(\mu_{elapsed}-k_r\log(M)\), where \(k_r\) is a scaling factor. The decision to release the reproduction key is then determined by comparing the perceived ‘elapsed time’ with the memorized target duration, \(\mu'_{post}\):

\[|\mu'_{post}-(\mu_{elapsed}-k_r\log(M))|< \delta \]

which is equivalent to comparing the sensory elapsed time to \(\mu'_{post}+k_r\log(M)\).

We then transform the log-scale estimate to the linear scale using the log-normal moment formulas:

\[\mu_r = e^{\mu'_{post}+k_r\log(M) + \sigma_{post}^2/2}\]

\[\sigma_r^2 = (e^{\sigma_{post}^2} -1)e^{2(\mu'_{post}+k_r\log(M))+\sigma_{post}^2}\]

where \(\mu_r\) and \(\sigma_r^2\) are the mean and variance of the reproduction distribution.
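The conversion above is the standard log-normal moment calculation; a small sketch with an illustrative (non-fitted) \(k_r\):

```python
import math

def reproduction_moments(mu_post, var_post, M, k_r=0.05):
    """Linear-scale mean and variance of the reproduced duration.

    k_r is an illustrative value. The shifted log-mean includes the
    load effect during reproduction, k_r * log(M).
    """
    mu_shift = mu_post + k_r * math.log(M)
    mean_r = math.exp(mu_shift + var_post / 2)
    var_r = (math.exp(var_post) - 1) * math.exp(2 * mu_shift + var_post)
    return mean_r, var_r

# With zero log-mean and modest variance, the linear-scale mean exceeds
# exp(0) = 1 because the log-normal distribution is right-skewed; a larger
# load shifts the reproduced mean upward via k_r * log(M).
m1, v1 = reproduction_moments(mu_post=0.0, var_post=0.04, M=1)
m5, v5 = reproduction_moments(mu_post=0.0, var_post=0.04, M=5)
```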

The impact of non-temporal noise diminishes as the duration increases, so we assume:

\[\sigma_{observed}^2 = \sigma_r^2 + \sigma_{non-temporal}^2/D\]

where \(\sigma_{non-temporal}^2\) is the variance of the non-temporal noise.

1.2 Setup

Code
import arviz as az
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import seaborn as sns
import datetime as dt
import os
import pytensor.tensor as pt

print('Last updated on:', dt.datetime.now())
Last updated on: 2025-11-20 21:11:36.754371

1.3 Data Loading

Code
# get the parent directory
cpath = os.path.abspath(os.path.join(os.getcwd(), '..'))
# import raw data from csv file cpath + data/AllData.csv
# Note: Adjust path if running from analysis/ directory
if not os.path.exists(os.path.join(cpath, 'data', 'AllValidData.csv')):
    # Try current directory if running as script
    cpath = os.getcwd() 
    if not os.path.exists(os.path.join(cpath, 'data', 'AllValidData.csv')):
         # Try one level up
         cpath = os.path.dirname(os.getcwd())

expData = pd.read_csv(os.path.join(cpath, 'data', 'AllValidData.csv'))

# map the column Exp to ExpName
expData['ExpName'] = expData['Exp'].map({
    'Exp1': 'Baseline', 
    'Exp2': 'Encoding', 
    'Exp3': 'Reproduction', 
    'Exp4': 'Both', 
    'Exp5': 'Both_gap'
})

# Drop rows from any unmapped experiments (e.g., Exp6 would map to NaN above)
expData = expData.dropna(subset=['ExpName'])

print(expData.ExpName.unique())
expData.head()
['Baseline' 'Encoding' 'Reproduction' 'Both' 'Both_gap']
WMSize DurLevel TPresent NT NSub curDur repDur WMRP valid stdRepDur Exp gap log_RP log_dur Gap ExpName
0 3 1 1 1 1 0.5 0.742038 2 1 0.203701 Exp1 1 -0.298354 -0.693147 0.5 Baseline
1 3 4 1 2 1 1.4 1.301903 1 1 0.297183 Exp1 1 0.263827 0.336472 0.5 Baseline
2 3 3 1 3 1 1.1 0.805957 1 1 0.201791 Exp1 1 -0.215724 0.095310 0.5 Baseline
3 3 5 2 4 1 1.7 1.253941 2 1 0.168043 Exp1 1 0.226292 0.530628 0.5 Baseline
4 3 4 2 5 1 1.4 0.709891 2 1 0.297183 Exp1 1 -0.342644 0.336472 0.5 Baseline

1.4 Model Definition

Code
def hModel(dat, constrain):
    """
    constrain: list of 4 integers [k_s, l_s, gap_effect, k_r]
    1 means parameter is free, 0 means fixed to 0 (or default)
    """
    # prepare the data
    subid = dat.NSub - 1 # starting from index 0
    nsub = len(dat.NSub.unique())  # number of subjects
    
    # log-transformed working memory size
    wm_idx = np.log(dat.WMSize.to_numpy())
    durs = dat.curDur.to_numpy()
    repDur = dat.repDur.to_numpy()
    lnRepDur = np.log(repDur)
    lnDur = np.log(durs)
    
    # Gap effect handling
    # If gap_effect (constrain[2]) is 1, we use the gap in the variance term
    # Original logic: wm_sig = np.log(dat.WMSize.to_numpy() + dat.gap.to_numpy()-1)
    if constrain[2] == 1:
        wm_sig = np.log(dat.WMSize.to_numpy() + dat.gap.to_numpy() - 1)
    else:
        wm_sig = np.log(dat.WMSize.to_numpy())

    niter = 2000
    
    with pm.Model() as WMmodel:
        # auxiliary variables (non-centered parameterization, individual level)
        var_s = pm.HalfNormal('var_s', 1, shape = nsub) # sensory noise
        epsilon_k = pm.HalfNormal('epsilon_k', 3, shape = nsub)
        
        # k_s: working memory coeff. on ticks (Encoding Mean)
        if constrain[0] == 1:
            k_s0 = pm.HalfNormal('k_s', 1) 
            k_s =  k_s0 * epsilon_k 
        else:
            k_s = np.zeros(nsub)
            
        # l_s: working memory impacts on variance (Encoding Variance)
        if constrain[1] == 1:
            l_s0 = pm.HalfNormal('l_s', 1) 
            l_s = l_s0 * epsilon_k
        else:
            l_s = np.zeros(nsub)
                  
        # k_r: working memory influence on reproduction (Reproduction Mean)
        if constrain[3] == 1:
            k_r0 = pm.Normal('k_r', 0, sigma = 1) 
            k_r =  k_r0 * epsilon_k
        else:
            k_r = np.zeros(nsub) 

        # prior (internal log encoding)
        epsilon = pm.HalfNormal('epsilon', 3, shape = nsub)
        mu_p = pm.Normal('mu_p', 0, sigma = 1, shape = nsub) # in log space
        var_p0 = pm.HalfNormal('var_p', 1) # in log-space
        var_p = var_p0 * epsilon

        sig_n0 = pm.HalfNormal('sig_n', 1)
        var_n = sig_n0 * epsilon
        
        # sensory measurement: log encoding with tick loss scaled by memory load,
        # mu_wm = log(D) - k_s * log(M), since wm_idx = log(WMSize)
        D_s = lnDur - k_s[subid] * wm_idx

        var_wm = var_s[subid] * (1 + l_s[subid] * wm_sig)
        
        # integration with prior
        w_p = var_wm / (var_p[subid] + var_wm)
        
        # posterior
        u_x = (1-w_p)*D_s + w_p * mu_p[subid] # posterior mean

        var_x = var_wm * var_p[subid] / (var_wm + var_p[subid])  # posterior variance
        
        # reproduction: shift the posterior mean by the load effect, k_r * log(M)
        u_x1 = u_x + k_r[subid] * wm_idx
        u_r = pt.exp(u_x1 + var_x/2) # log-normal mean of the reproduced duration
        
        # reproduced sigma: log-normal variance plus duration-scaled non-temporal noise
        # sig_r = sqrt( (exp(var_x)-1)*exp(2*u_x1 + var_x) + var_n/D )
        sig_r = pt.sqrt((pt.exp(var_x)-1)*pt.exp(2*u_x1 + var_x) + var_n[subid]/durs)

        # Data likelihood 
        resp_like = pm.Normal('resp_like', mu = u_r, sigma = sig_r, observed = repDur)
        
        # Sampling
        # Using smaller tune/draws for testing, increase for production
        trace = pm.sample(draws=niter, tune=1000, progressbar=True, return_inferencedata=True, target_accept=0.85, idata_kwargs={'log_likelihood': True})
                  
    return trace, WMmodel

1.5 Helper Functions

Code
def getPosteriorSummary(posterior_samples, dat):
    # collect mean/sd of posterior-predictive samples per subject and condition
    
    nsub = len(dat.NSub.unique())
    nDur = len(dat.curDur.unique())
    nWM = len(dat.WMSize.unique())
    
    # Check if gap column exists (for Both_gap experiment)
    has_gap = 'gap' in dat.columns
    
    # Pre-calculate unique values to speed up
    unique_subs = dat.NSub.unique()
    unique_durs = dat.curDur.unique()
    unique_wms = dat.WMSize.unique()
    unique_gaps = dat.gap.unique() if has_gap else [None]
    
    records = []
    
    # This loop is slow, but keeping original logic for consistency for now
    # Optimized slightly by using list collection
    for subid in unique_subs:
        for curDur in unique_durs:
            for WMSize in unique_wms:
                for gap in unique_gaps:
                    if has_gap:
                        idx = (dat.NSub == subid) & (dat.curDur == curDur) & (dat.WMSize == WMSize) & (dat.gap == gap)
                    else:
                        idx = (dat.NSub == subid) & (dat.curDur == curDur) & (dat.WMSize == WMSize)
                    
                    if np.sum(idx) > 0:
                        mPred = np.mean(posterior_samples[:, idx])
                        sdPred = np.std(posterior_samples[:, idx])
                        record = {
                            'NSub': subid, 
                            'curDur': curDur, 
                            'WMSize': WMSize, 
                            'mPred': mPred, 
                            'sdPred': sdPred
                        }
                        if has_gap:
                            record['gap'] = gap
                        records.append(record)
    
    posterior_df = pd.DataFrame(records)
    
    # calculate the observed means and standard deviations
    group_cols = ['NSub', 'curDur', 'WMSize']
    if has_gap:
        group_cols.append('gap')
    
    mdat = dat.groupby(group_cols).agg(
        repDur_mean=('repDur', 'mean'),
        repDur_std=('repDur', 'std')
    ).reset_index()

    # Perform the merge
    mdat = pd.merge(mdat, posterior_df, on=group_cols)

    # Calculate errors and CVs
    mdat['repErr'] = mdat['repDur_mean'] - mdat['curDur']
    mdat['predErr'] = mdat['mPred'] - mdat['curDur']
    mdat['repCV'] = mdat['repDur_std'] / mdat['repDur_mean']
    mdat['predCV'] = mdat['sdPred'] / mdat['mPred']
    return mdat

def runModel(dat, constrain, model_name, output_dir):
    print(f"Running {model_name}...")
    trace, model = hModel(dat, constrain)
    
    with model:
        ppc = pm.sample_posterior_predictive(trace, var_names=['resp_like'])
        posterior_resp = ppc.posterior_predictive['resp_like'].to_numpy() 
        # reshape: (chains*draws, observations)
        posterior_samples = posterior_resp.reshape(-1, posterior_resp.shape[2]) 
        
        # Ensure log_likelihood is present
        if not hasattr(trace, 'log_likelihood'):
            pm.compute_log_likelihood(trace) 
        
    # get summary
    mdat = getPosteriorSummary(posterior_samples, dat)
    para = az.summary(trace)

    # save
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        
    az.to_netcdf(trace, os.path.join(output_dir, f'{model_name}_trace.nc'))
    mdat.to_csv(os.path.join(output_dir, f'{model_name}_mdat.csv'))
    para.to_csv(os.path.join(output_dir, f'{model_name}_para.csv'))
    
    return trace, model, para, mdat

1.6 Main Analysis Loop

Code
experiments = ['Baseline', 'Encoding', 'Reproduction', 'Both', 'Both_gap']

# Define models and their constraints: [k_s, l_s, gap_effect, k_r]
# k_s: Encoding Mean (Load)
# l_s: Encoding Variance (Load)
# gap_effect: Encoding Variance (Gap)
# k_r: Reproduction Mean (Load)

models_config = {
    'NULL': [0, 0, 0, 0],
    'EncodingOnly': [1, 1, 0, 0],
    'ReproductionOnly': [0, 0, 0, 1],
    'FreeParameters': [1, 1, 1, 1]
}

# Experiment-wise constraints
# Baseline: No load effect -> NULL
# Encoding: Load during encoding -> EncodingOnly
# Reproduction: Load during reproduction -> ReproductionOnly
# Both: Load during both encoding and reproduction -> free k_s, l_s, k_r (no gap term)
# Both_gap: Load + Gap -> FreeParameters + Gap effect

experiment_wise_map = {
    'Baseline': [0, 0, 0, 0],
    'Encoding': [1, 1, 0, 0],
    'Reproduction': [0, 0, 0, 1],
    'Both': [1, 1, 0, 1],
    'Both_gap': [1, 1, 1, 1]
}

OUTPUT_PATH = os.path.join(cpath, 'data', 'model_comparison_results')
if not os.path.exists(OUTPUT_PATH):
    os.makedirs(OUTPUT_PATH)

# To store traces for comparison
all_traces = {exp: {} for exp in experiments}

for exp in experiments:
    print(f"\n=== Processing Experiment: {exp} ===")
    dat = expData[expData['ExpName'] == exp].copy()
    
    # 1. Run Standard Models
    for m_name, constraints in models_config.items():
        
        # Construct full model name for saving
        full_model_name = f"{m_name}_{exp}"
        save_dir = os.path.join(cpath, 'data', m_name)
        
        # Check if already exists to avoid re-running (optional, but good for dev)
        trace_path = os.path.join(save_dir, f'{exp}_trace.nc')
        if os.path.exists(trace_path):
            print(f"Loading existing trace for {full_model_name}")
            trace = az.from_netcdf(trace_path)
        else:
            trace, _, _, _ = runModel(dat, constraints, exp, save_dir)
            
        all_traces[exp][m_name] = trace

    # 2. Run Experiment-wise Model
    exp_constraints = experiment_wise_map[exp]
    m_name = 'Experimentwise'
    save_dir = os.path.join(cpath, 'data', m_name)
    trace_path = os.path.join(save_dir, f'{exp}_trace.nc')
    
    if os.path.exists(trace_path):
        print(f"Loading existing trace for {m_name}_{exp}")
        trace = az.from_netcdf(trace_path)
    else:
        trace, _, _, _ = runModel(dat, exp_constraints, exp, save_dir)
    
    all_traces[exp][m_name] = trace

=== Processing Experiment: Baseline ===
Loading existing trace for NULL_Baseline
Loading existing trace for EncodingOnly_Baseline
Loading existing trace for ReproductionOnly_Baseline
Loading existing trace for FreeParameters_Baseline
Loading existing trace for Experimentwise_Baseline

=== Processing Experiment: Encoding ===
Loading existing trace for NULL_Encoding
Loading existing trace for EncodingOnly_Encoding
Loading existing trace for ReproductionOnly_Encoding
Loading existing trace for FreeParameters_Encoding
Loading existing trace for Experimentwise_Encoding

=== Processing Experiment: Reproduction ===
Loading existing trace for NULL_Reproduction
Loading existing trace for EncodingOnly_Reproduction
Loading existing trace for ReproductionOnly_Reproduction
Loading existing trace for FreeParameters_Reproduction
Loading existing trace for Experimentwise_Reproduction

=== Processing Experiment: Both ===
Loading existing trace for NULL_Both
Loading existing trace for EncodingOnly_Both
Loading existing trace for ReproductionOnly_Both
Loading existing trace for FreeParameters_Both
Loading existing trace for Experimentwise_Both

=== Processing Experiment: Both_gap ===
Loading existing trace for NULL_Both_gap
Loading existing trace for EncodingOnly_Both_gap
Loading existing trace for ReproductionOnly_Both_gap
Loading existing trace for FreeParameters_Both_gap
Loading existing trace for Experimentwise_Both_gap

1.7 Model Comparison

Code
def check_model_reliability(traces_dict, ic='loo', pareto_k_thresh=0.7):
    try:
        compare_df = az.compare(traces_dict, ic=ic, scale='deviance')
    except Exception as e:
        print(f"Error computing {ic}: {e}")
        return None

    compare_df['reliable'] = True
    
    if ic.lower() == 'loo':
        for name, idata in traces_dict.items():
            loo = az.loo(idata)
            if np.any(loo.pareto_k > pareto_k_thresh):
                compare_df.loc[name, 'reliable'] = False
                print(f"LOO warning: {name} has Pareto k > {pareto_k_thresh}")
    
    return compare_df

summary_records = []

for exp in experiments:
    print(f"\n=== Comparing models for experiment: {exp} ===")
    traces = all_traces[exp]
    
    if len(traces) < 2:
        print("Not enough models to compare.")
        continue
        
    # LOO Comparison
    comp_loo = check_model_reliability(traces, ic='loo')
    if comp_loo is not None:
        comp_loo['experiment'] = exp
        comp_loo['ic_type'] = 'loo'
        comp_loo['model'] = comp_loo.index
        summary_records.append(comp_loo)
        
        # Save individual comparison
        comp_loo.to_csv(os.path.join(OUTPUT_PATH, f'{exp}_loo_comparison.csv'))
        print(comp_loo[['rank', 'elpd_loo', 'elpd_diff', 'weight', 'reliable']])

# Combine all results
if summary_records:
    full_summary = pd.concat(summary_records, ignore_index=True)
    full_summary.to_csv(os.path.join(OUTPUT_PATH, 'all_models_comparison_summary.csv'), index=False)
    print("\nSaved full summary to all_models_comparison_summary.csv")

=== Comparing models for experiment: Baseline ===
/Users/strongway/miniconda3/envs/pymc_env/lib/python3.12/site-packages/arviz/stats/stats.py:792: UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
  warnings.warn(
LOO warning: FreeParameters has Pareto k > 0.7
LOO warning: Experimentwise has Pareto k > 0.7
                  rank     elpd_loo  elpd_diff        weight  reliable
Experimentwise       0 -1987.241177   0.000000  5.091499e-01     False
EncodingOnly         1 -1987.093234   0.147944  4.908501e-01      True
NULL                 2 -1986.406093   0.835084  0.000000e+00      True
FreeParameters       3 -1985.045273   2.195904  1.657357e-16     False
ReproductionOnly     4 -1984.441453   2.799724  7.771561e-16      True

=== Comparing models for experiment: Encoding ===
LOO warning: EncodingOnly has Pareto k > 0.7
                  rank     elpd_loo  elpd_diff        weight  reliable
EncodingOnly         0  1015.229238   0.000000  8.129124e-01     False
Experimentwise       1  1015.340878   0.111641  0.000000e+00      True
FreeParameters       2  1015.657215   0.427977  1.179410e-14      True
ReproductionOnly     3  1061.130252  45.901014  1.870876e-01      True
NULL                 4  1081.870734  66.641497  0.000000e+00      True

=== Comparing models for experiment: Reproduction ===
                  rank     elpd_loo  elpd_diff        weight  reliable
FreeParameters       0 -1182.343105   0.000000  9.224427e-01      True
ReproductionOnly     1 -1181.256203   1.086903  4.857846e-18      True
Experimentwise       2 -1180.893746   1.449359  6.128422e-18      True
NULL                 3 -1130.425169  51.917936  7.755726e-02      True
EncodingOnly         4 -1128.082579  54.260526  2.220446e-16      True

=== Comparing models for experiment: Both ===
LOO warning: NULL has Pareto k > 0.7
LOO warning: EncodingOnly has Pareto k > 0.7
LOO warning: ReproductionOnly has Pareto k > 0.7
LOO warning: FreeParameters has Pareto k > 0.7
LOO warning: Experimentwise has Pareto k > 0.7
                  rank     elpd_loo  elpd_diff        weight  reliable
FreeParameters       0 -1112.641838   0.000000  6.206683e-01     False
Experimentwise       1 -1112.397604   0.244234  2.507610e-15     False
ReproductionOnly     2 -1110.315077   2.326761  2.963072e-01     False
EncodingOnly         3 -1109.525831   3.116007  0.000000e+00     False
NULL                 4 -1077.741341  34.900497  8.302450e-02     False

=== Comparing models for experiment: Both_gap ===
                  rank     elpd_loo   elpd_diff        weight  reliable
Experimentwise       0 -3957.246894    0.000000  7.748865e-01      True
FreeParameters       1 -3956.313858    0.933036  2.472375e-13      True
EncodingOnly         2 -3928.042296   29.204597  0.000000e+00      True
ReproductionOnly     3 -3906.740600   50.506294  1.884825e-01      True
NULL                 4 -3562.832144  394.414749  3.663096e-02      True

Saved full summary to all_models_comparison_summary.csv

1.8 Visualization

Code
# load summary records
full_summary = pd.read_csv(os.path.join(OUTPUT_PATH, 'all_models_comparison_summary.csv'))
# Plotting Delta IC Heatmap
plt.figure(figsize=(10, 6))

# Pivot data: Rows=Model, Cols=Experiment, Values=elpd_diff
heatmap_data = full_summary.pivot(index='model', columns='experiment', values='elpd_diff')

# Reorder columns if needed to match experiment order
exp_order = [e for e in experiments if e in heatmap_data.columns]
heatmap_data = heatmap_data[exp_order]

# Create heatmap
# vmin=0, vmax=5 clips the color range; values > 5 saturate at the darkest color.
# 'YlOrRd' maps 0 (the best model) to light yellow and larger differences to dark red.
sns.heatmap(heatmap_data, annot=True, fmt=".1f", cmap="YlOrRd", vmin=0, vmax=5, linewidths=.5)

plt.title('Model Comparison (Delta LOO)\n(Values > 5 indicate significantly worse fit)')
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'model_comparison_heatmap.png'))
plt.show()

1.9 Parameter Visualization

Code
def visualize_model_parameters(model_name, params=['k_s', 'l_s', 'k_r'], experiments=None):
    """
    Visualize estimated parameters across experiments for a given model.
    
    Parameters:
    -----------
    model_name : str
        Name of the model (e.g., 'FreeParameters', 'EncodingOnly')
    params : list
        List of parameter names to visualize (default: ['k_s', 'l_s', 'k_r'])
    experiments : list
        List of experiments to include (default: all experiments)
    """
    if experiments is None:
        experiments = ['Baseline', 'Encoding', 'Reproduction', 'Both', 'Both_gap']
    
    # Collect parameter estimates
    param_data = {p: {'mean': [], 'sd': [], 'exp': []} for p in params}
    
    for exp in experiments:
        para_file = os.path.join(cpath, 'data', model_name, f'{exp}_para.csv')
        
        if not os.path.exists(para_file):
            print(f"Warning: {para_file} not found, skipping {exp}")
            continue
            
        para_df = pd.read_csv(para_file, index_col=0)
        
        for param in params:
            # Check if parameter exists in the file
            if param in para_df.index:
                param_data[param]['mean'].append(para_df.loc[param, 'mean'])
                param_data[param]['sd'].append(para_df.loc[param, 'sd'])
                param_data[param]['exp'].append(exp)
            # Subject-level parameters appear as 'param[i]' rows (e.g. 'var_s[0]',
            # 'var_s[1]', ...); average the matching rows across subjects
            elif any(idx.startswith(f'{param}[') for idx in para_df.index):
                matching_rows = [idx for idx in para_df.index if idx.startswith(f'{param}[')]
                param_data[param]['mean'].append(para_df.loc[matching_rows, 'mean'].mean())
                param_data[param]['sd'].append(para_df.loc[matching_rows, 'sd'].mean())
                param_data[param]['exp'].append(exp)
            else:
                # Parameter not estimated in this model/experiment
                param_data[param]['mean'].append(0)
                param_data[param]['sd'].append(0)
                param_data[param]['exp'].append(exp)
    
    # Create subplots
    n_params = len(params)
    fig, axes = plt.subplots(1, n_params, figsize=(5*n_params, 4))
    
    if n_params == 1:
        axes = [axes]
    
    for i, param in enumerate(params):
        ax = axes[i]
        
        # Create bar plot with error bars
        x_pos = np.arange(len(param_data[param]['exp']))
        ax.bar(x_pos, param_data[param]['mean'], yerr=param_data[param]['sd'], 
               color='gray', alpha=0.7, capsize=5)
        
        ax.set_xticks(x_pos)
        ax.set_xticklabels(param_data[param]['exp'], rotation=45, ha='right')
        ax.set_ylabel('Parameter estimate (mean ± SD)')
        ax.set_title(f'${param}$')  # mathtext renders 'k_s' as k with subscript s
        ax.axhline(y=0, color='black', linestyle='--', linewidth=0.5)
    
    plt.suptitle(f'Parameter Estimates: {model_name}', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_PATH, f'{model_name}_parameters.png'), dpi=300, bbox_inches='tight')
    plt.show()

# Visualize FreeParameters model
visualize_model_parameters('FreeParameters', params=['k_s', 'l_s', 'k_r'])
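The loader above assumes a particular `_para.csv` layout: one row per parameter (the index column) with `mean` and `sd` columns, and subject-level parameters stored as indexed rows such as `k_s[0]`, `k_s[1]`. A minimal sketch of the two lookup paths on a made-up summary table (the numbers are illustrative, not fitted estimates):

```python
import pandas as pd

# Hypothetical parameter summary mimicking the assumed *_para.csv layout:
# group-level parameters on plain rows, subject-level ones on 'name[i]' rows.
para_df = pd.DataFrame(
    {'mean': [0.12, 0.30, 0.10, 0.20], 'sd': [0.02, 0.05, 0.04, 0.06]},
    index=['k_s', 'l_s', 'k_r[0]', 'k_r[1]'],
)

# Group-level parameter: read off directly.
ks_mean = para_df.loc['k_s', 'mean']

# Subject-level parameter: average the matching 'k_r[...]' rows,
# as in visualize_model_parameters above.
rows = [idx for idx in para_df.index if idx.startswith('k_r[')]
kr_mean = para_df.loc[rows, 'mean'].mean()
print(ks_mean, kr_mean)
```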

1.9.1 Experimentwise model parameters

Code
# visualize Experimentwise model
visualize_model_parameters('Experimentwise', params=['k_s', 'l_s', 'k_r'])

1.9.2 Encoding-only and Reproduction-only model parameters

Code
visualize_model_parameters('EncodingOnly')
visualize_model_parameters('ReproductionOnly')

1.10 Model Fitting Visualization

Code
def visualize_model_fit(model_name, experiments=None):
    """
    Visualize model fit: observed vs predicted reproduction biases.
    
    Parameters:
    -----------
    model_name : str
        Name of the model (e.g., 'FreeParameters', 'Experimentwise')
    experiments : list
        List of experiments to visualize (default: all 5 experiments)
    """
    if experiments is None:
        experiments = ['Encoding', 'Reproduction', 'Baseline', 'Both', 'Both_gap']
    
    # Collect all data into a single DataFrame
    all_data = []
    
    for exp in experiments:
        mdat_file = os.path.join(cpath, 'data', model_name, f'{exp}_mdat.csv')
        
        if not os.path.exists(mdat_file):
            continue
        
        mdat = pd.read_csv(mdat_file)
        
        # Add experiment column
        mdat['Experiment'] = exp
        
        # For Both_gap, label trials by retention-interval (gap) length
        if exp == 'Both_gap':
            mdat['gap_label'] = mdat['gap'].map({mdat['gap'].min(): 'short', mdat['gap'].max(): 'long'})
        else:
            mdat['gap_label'] = 'none'
        
        all_data.append(mdat)
    
    # Combine all data
    combined_data = pd.concat(all_data, ignore_index=True)

    # Aggregate: compute mean and SEM across subjects
    agg_data = combined_data.groupby(['Experiment', 'WMSize', 'curDur', 'gap_label']).agg({
        'repErr': ['mean', 'sem'],
        'predErr': 'mean',
        'repCV': ['mean', 'sem'],
        'predCV': 'mean'
    }).reset_index()

    # Flatten column names
    agg_data.columns = ['Experiment', 'WMSize', 'curDur', 'gap_label',
                        'repErr_mean', 'repErr_sem', 'predErr',
                        'repCV_mean', 'repCV_sem', 'predCV']

    # Create figure - 3 rows layout matching template
    # Note: rows 1 & 2 share the duration x-axis; row 3 plots bias vs. bias, so its axes are not shared
    subplot_width = 2.5
    subplot_height = 2.2
    fig, axes = plt.subplots(3, len(experiments),
                            figsize=(subplot_width*len(experiments), subplot_height*3))
    if len(experiments) == 1:
        axes = axes.reshape(-1, 1)

    # Color palette for memory loads
    colors = ['#d9d9d9', '#838383', '#3b3b3b']
    
    for idx, exp in enumerate(experiments):
        exp_data = agg_data[agg_data['Experiment'] == exp]

        # Get axes for this experiment (column)
        ax_bias = axes[0, idx]
        ax_cv = axes[1, idx]
        ax_scatter = axes[2, idx]

        if len(exp_data) == 0:
            for ax in [ax_bias, ax_cv, ax_scatter]:
                ax.text(0.5, 0.5, f'No data\nfor {exp}', ha='center', va='center', transform=ax.transAxes)
            ax_bias.set_title(exp, fontsize=10)
            continue

        # === ROW 1: Bias plots ===
        if exp == 'Both_gap':
            for gap_label in ['short', 'long']:
                gap_data = exp_data[exp_data['gap_label'] == gap_label]
                linestyle = '-' if gap_label == 'short' else '--'
                marker = 'o' if gap_label == 'short' else 's'

                for wm_idx, wm in enumerate(sorted(gap_data['WMSize'].unique())):
                    wm_data = gap_data[gap_data['WMSize'] == wm].sort_values('curDur')

                    # Observed with error bars
                    ax_bias.errorbar(wm_data['curDur'], wm_data['repErr_mean'],
                                    yerr=wm_data['repErr_sem'],
                                    fmt=marker, color=colors[wm_idx], markersize=5,
                                    alpha=0.6, capsize=2, zorder=3)
                    # Predicted
                    ax_bias.plot(wm_data['curDur'], wm_data['predErr'],
                                color=colors[wm_idx], linestyle=linestyle, linewidth=1.5, zorder=2)
        else:
            for wm_idx, wm in enumerate(sorted(exp_data['WMSize'].unique())):
                wm_data = exp_data[exp_data['WMSize'] == wm].sort_values('curDur')

                # Observed with error bars
                ax_bias.errorbar(wm_data['curDur'], wm_data['repErr_mean'],
                                yerr=wm_data['repErr_sem'],
                                fmt='o', color=colors[wm_idx], markersize=5,
                                alpha=0.6, capsize=2, zorder=3)
                # Predicted
                ax_bias.plot(wm_data['curDur'], wm_data['predErr'],
                            color=colors[wm_idx], linestyle='-', linewidth=1.5, zorder=2)

        ax_bias.axhline(y=0, color='black', linestyle='--', linewidth=0.5, zorder=1)
        ax_bias.set_title(exp, fontsize=10)
        ax_bias.set_xlim(0.3, 1.8)
        ax_bias.tick_params(labelsize=7)
        if idx == 0:
            ax_bias.set_ylabel('Reproduction bias (s)', fontsize=8)

        # === ROW 2: CV plots ===
        if exp == 'Both_gap':
            for gap_label in ['short', 'long']:
                gap_data = exp_data[exp_data['gap_label'] == gap_label]
                linestyle = '-' if gap_label == 'short' else '--'
                marker = 'o' if gap_label == 'short' else 's'

                for wm_idx, wm in enumerate(sorted(gap_data['WMSize'].unique())):
                    wm_data = gap_data[gap_data['WMSize'] == wm].sort_values('curDur')

                    # Observed CV with error bars
                    ax_cv.errorbar(wm_data['curDur'], wm_data['repCV_mean'],
                                  yerr=wm_data['repCV_sem'],
                                  fmt=marker, color=colors[wm_idx], markersize=5,
                                  alpha=0.6, capsize=2, zorder=3)
                    # Predicted CV
                    ax_cv.plot(wm_data['curDur'], wm_data['predCV'],
                              color=colors[wm_idx], linestyle=linestyle, linewidth=1.5, zorder=2)
        else:
            for wm_idx, wm in enumerate(sorted(exp_data['WMSize'].unique())):
                wm_data = exp_data[exp_data['WMSize'] == wm].sort_values('curDur')

                # Observed CV with error bars
                ax_cv.errorbar(wm_data['curDur'], wm_data['repCV_mean'],
                              yerr=wm_data['repCV_sem'],
                              fmt='o', color=colors[wm_idx], markersize=5,
                              alpha=0.6, capsize=2, zorder=3)
                # Predicted CV
                ax_cv.plot(wm_data['curDur'], wm_data['predCV'],
                          color=colors[wm_idx], linestyle='-', linewidth=1.5, zorder=2)

        ax_cv.set_xlim(0.3, 1.8)
        ax_cv.tick_params(labelsize=7)
        if idx == 0:
            ax_cv.set_ylabel('Coefficient of Variation', fontsize=8)

        # === ROW 3: Scatter plot (individual data points) ===
        exp_raw_data = combined_data[combined_data['Experiment'] == exp]

        if exp == 'Both_gap':
            for gap_label in ['short', 'long']:
                gap_data = exp_raw_data[exp_raw_data['gap_label'] == gap_label]
                marker = 'o' if gap_label == 'short' else 's'

                for wm_idx, wm in enumerate(sorted(gap_data['WMSize'].unique())):
                    wm_data = gap_data[gap_data['WMSize'] == wm]
                    ax_scatter.scatter(wm_data['repErr'], wm_data['predErr'],
                                      color=colors[wm_idx], marker=marker,
                                      s=15, alpha=0.5, edgecolors='none')
        else:
            for wm_idx, wm in enumerate(sorted(exp_raw_data['WMSize'].unique())):
                wm_data = exp_raw_data[exp_raw_data['WMSize'] == wm]
                ax_scatter.scatter(wm_data['repErr'], wm_data['predErr'],
                                  color=colors[wm_idx], marker='o',
                                  s=15, alpha=0.5, edgecolors='none')

        # Add diagonal line and set axis limits centered around 0
        # Determine appropriate limits based on data
        all_obs = exp_raw_data['repErr'].values
        all_pred = exp_raw_data['predErr'].values
        combined_vals = np.concatenate([all_obs, all_pred])
        max_abs = max(abs(combined_vals.min()), abs(combined_vals.max()))
        lim = max_abs * 1.1  # Add 10% margin

        ax_scatter.plot([-lim, lim], [-lim, lim], 'k--', linewidth=0.5, zorder=1)
        ax_scatter.axhline(y=0, color='gray', linestyle=':', linewidth=0.5, alpha=0.5)
        ax_scatter.axvline(x=0, color='gray', linestyle=':', linewidth=0.5, alpha=0.5)
        ax_scatter.set_xlim(-lim, lim)
        ax_scatter.set_ylim(-lim, lim)
        ax_scatter.set_xlabel('Observed Bias (s)', fontsize=8)
        ax_scatter.tick_params(labelsize=7)
        ax_scatter.set_aspect('equal', adjustable='box')
        if idx == 0:
            ax_scatter.set_ylabel('Predicted Bias (s)', fontsize=8)
            ax_scatter.text(-0.35, 0.5, 'c', transform=ax_scatter.transAxes,
                           fontsize=12, fontweight='bold', va='center')

        # Add row labels on the left
        if idx == 0:
            ax_bias.text(-0.35, 0.5, 'a', transform=ax_bias.transAxes,
                        fontsize=12, fontweight='bold', va='center')
            ax_cv.text(-0.35, 0.5, 'b', transform=ax_cv.transAxes,
                      fontsize=12, fontweight='bold', va='center')

    # Legend - placed at bottom for better layout
    from matplotlib.lines import Line2D

    # Memory load legend (always present)
    legend_elements_load = [
        Line2D([0], [0], color=colors[0], lw=2, label='low'),
        Line2D([0], [0], color=colors[1], lw=2, label='medium'),
        Line2D([0], [0], color=colors[2], lw=2, label='high')
    ]

    if 'Both_gap' in experiments:
        # Gap legend (only for Both_gap experiment)
        legend_elements_gap = [
            Line2D([0], [0], color='gray', lw=2, linestyle='-', marker='o', markersize=5, label='short'),
            Line2D([0], [0], color='gray', lw=2, linestyle='--', marker='s', markersize=5, label='long')
        ]

        # Create two legends side by side at the bottom
        leg1 = fig.legend(handles=legend_elements_load, loc='lower center',
                         bbox_to_anchor=(0.35, -0.02), frameon=False,
                         title='Memory Load', ncol=3, fontsize=8, title_fontsize=9)
        leg2 = fig.legend(handles=legend_elements_gap, loc='lower center',
                         bbox_to_anchor=(0.68, -0.02), frameon=False,
                         title='Gap', ncol=2, fontsize=8, title_fontsize=9)
        fig.add_artist(leg1)  # keep the first legend alongside the second
    else:
        # Single legend for memory load only
        fig.legend(handles=legend_elements_load, loc='lower center',
                  bbox_to_anchor=(0.5, -0.02), frameon=False,
                  title='Memory Load', ncol=3, fontsize=8, title_fontsize=9)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.06, hspace=0.3)  # Adjust spacing for 3 rows
    plt.savefig(os.path.join(OUTPUT_PATH, f'{model_name}_model_fit.png'), dpi=300, bbox_inches='tight')
    plt.show()

# Visualize FreeParameters model fit
visualize_model_fit('FreeParameters')
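For reference, the mean/SEM aggregation inside `visualize_model_fit` produces a MultiIndex over columns, which is then flattened by direct assignment. A self-contained sketch with toy numbers (column names follow the `_mdat.csv` convention above; the values are made up):

```python
import pandas as pd

# Toy trial-level data standing in for the model-fit CSVs.
mdat = pd.DataFrame({
    'WMSize':  [1, 1, 3, 3],
    'curDur':  [0.5, 0.5, 0.5, 0.5],
    'repErr':  [0.10, 0.20, -0.10, -0.30],
    'predErr': [0.12, 0.12, -0.18, -0.18],
})

# Aggregating with both 'mean' and 'sem' yields MultiIndex columns
# like ('repErr', 'mean') and ('repErr', 'sem') ...
agg = mdat.groupby(['WMSize', 'curDur']).agg(
    {'repErr': ['mean', 'sem'], 'predErr': 'mean'}).reset_index()

# ... which visualize_model_fit flattens by assigning plain names.
agg.columns = ['WMSize', 'curDur', 'repErr_mean', 'repErr_sem', 'predErr']
print(agg)
```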

1.10.1 Other Models

Code
visualize_model_fit('Experimentwise')
visualize_model_fit('EncodingOnly')
visualize_model_fit('ReproductionOnly')